{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3E96e1UKQ8uR"
      },
      "source": [
        "# MoViNet Tutorial\n",
        "\n",
        "This notebook provides basic example code to build, run, and fine-tune [MoViNets (Mobile Video Networks)](https://arxiv.org/pdf/2103.11511.pdf).\n",
        "\n",
        "Pretrained models are provided by [TensorFlow Hub](https://tfhub.dev/google/collections/movinet/) and the [TensorFlow Model Garden](https://github.com/tensorflow/models/tree/master/official/projects/movinet), trained on [Kinetics 600](https://deepmind.com/research/open-source/kinetics) for video action classification. All models use TensorFlow 2 with Keras for inference and training.\n",
        "\n",
        "The following steps will be performed:\n",
        "\n",
        "1. [Running base model inference with TensorFlow Hub](#scrollTo=6g0tuFvf71S9\u0026line=8\u0026uniqifier=1)\n",
        "2. [Running streaming model inference with TensorFlow Hub and plotting predictions](#scrollTo=ADrHPmwGcBZ5\u0026line=4\u0026uniqifier=1)\n",
        "3. [Exporting a streaming model to TensorFlow Lite for mobile](#scrollTo=W3CLHvubvdSI\u0026line=3\u0026uniqifier=1)\n",
        "4. [Fine-tuning a base model with the TensorFlow Model Garden](#scrollTo=_s-7bEoa3f8g\u0026line=11\u0026uniqifier=1)\n",
        "\n",
        "![jumping jacks plot](https://storage.googleapis.com/tf_model_garden/vision/movinet/artifacts/jumpingjacks_plot.gif)\n",
        "\n",
        "To generate video plots like the one above, see [section 2](#scrollTo=ADrHPmwGcBZ5\u0026line=4\u0026uniqifier=1)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8_oLnvJy7kz5"
      },
      "source": [
        "## Setup\n",
        "\n",
        "For inference on the smaller models (A0-A2), a CPU is sufficient for this Colab. For fine-tuning, a GPU is recommended.\n",
        "\n",
        "To select a GPU in Colab, open `Runtime \u003e Change runtime type` in the top menu and choose `GPU` from the `Hardware accelerator` dropdown."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s3khsunT7kWa"
      },
      "outputs": [],
      "source": [
        "# Install packages\n",
        "\n",
        "# tf-models-official is the stable Model Garden package\n",
        "# tf-models-nightly includes latest changes\n",
        "!pip install -U -q \"tf-models-official\"\n",
        "\n",
        "\n",
        "# Install the mediapy package for visualizing images/videos.\n",
        "# See https://github.com/google/mediapy\n",
        "!command -v ffmpeg \u003e/dev/null || (apt update \u0026\u0026 apt install -y ffmpeg)\n",
        "!pip install -q mediapy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dI_1csl6Q-gH"
      },
      "outputs": [],
      "source": [
        "# Run imports\n",
        "import os\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import mediapy as media\n",
        "import numpy as np\n",
        "import PIL\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow_hub as hub\n",
        "import tqdm\n",
        "import absl.logging\n",
        "\n",
        "tf.get_logger().setLevel('ERROR')\n",
        "absl.logging.set_verbosity(absl.logging.ERROR)\n",
        "mpl.rcParams.update({\n",
        "    'font.size': 10,\n",
        "})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OnFqOXazoWgy"
      },
      "source": [
        "Run the cell below to define helper functions and create variables."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "dx55NK3ZoZeh"
      },
      "outputs": [],
      "source": [
        "#@title Run this cell to set up some helper code.\n",
        "\n",
        "# Download Kinetics 600 label map\n",
        "!wget https://raw.githubusercontent.com/tensorflow/models/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/kinetics_600_labels.txt -O labels.txt -q\n",
        "\n",
        "with tf.io.gfile.GFile('labels.txt') as f:\n",
        "  lines = f.readlines()\n",
        "  KINETICS_600_LABELS_LIST = [line.strip() for line in lines]\n",
        "  KINETICS_600_LABELS = tf.constant(KINETICS_600_LABELS_LIST)\n",
        "\n",
        "def get_top_k(probs, k=5, label_map=KINETICS_600_LABELS):\n",
        "  \"\"\"Returns the top k labels and probabilities from a probability vector.\"\"\"\n",
        "  top_predictions = tf.argsort(probs, axis=-1, direction='DESCENDING')[:k]\n",
        "  top_labels = tf.gather(label_map, top_predictions, axis=-1)\n",
        "  top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
        "  top_probs = tf.gather(probs, top_predictions, axis=-1).numpy()\n",
        "  return tuple(zip(top_labels, top_probs))\n",
        "\n",
        "def predict_top_k(model, video, k=5, label_map=KINETICS_600_LABELS):\n",
        "  \"\"\"Outputs the top k model labels and probabilities on the given video.\"\"\"\n",
        "  outputs = model.predict(video[tf.newaxis])[0]\n",
        "  probs = tf.nn.softmax(outputs)\n",
        "  return get_top_k(probs, k=k, label_map=label_map)\n",
        "\n",
        "def load_movinet_from_hub(model_id, model_mode, hub_version=3):\n",
        "  \"\"\"Loads a MoViNet model from TF Hub.\"\"\"\n",
        "  hub_url = f'https://tfhub.dev/tensorflow/movinet/{model_id}/{model_mode}/kinetics-600/classification/{hub_version}'\n",
        "\n",
        "  encoder = hub.KerasLayer(hub_url, trainable=True)\n",
        "\n",
        "  inputs = tf.keras.layers.Input(\n",
        "      shape=[None, None, None, 3],\n",
        "      dtype=tf.float32)\n",
        "\n",
        "  if model_mode == 'base':\n",
        "    inputs = dict(image=inputs)\n",
        "  else:\n",
        "    # Define the state inputs, which is a dict that maps state names to tensors.\n",
        "    init_states_fn = encoder.resolved_object.signatures['init_states']\n",
        "    state_shapes = {\n",
        "        name: ([s if s \u003e 0 else None for s in state.shape], state.dtype)\n",
        "        for name, state in init_states_fn(tf.constant([0, 0, 0, 0, 3])).items()\n",
        "    }\n",
        "    states_input = {\n",
        "        name: tf.keras.Input(shape[1:], dtype=dtype, name=name)\n",
        "        for name, (shape, dtype) in state_shapes.items()\n",
        "    }\n",
        "\n",
        "    # The inputs to the model are the states and the video\n",
        "    inputs = {**states_input, 'image': inputs}\n",
        "\n",
        "  # Output shape: [batch_size, 600]\n",
        "  outputs = encoder(inputs)\n",
        "\n",
        "  model = tf.keras.Model(inputs, outputs)\n",
        "  model.build([1, 1, 1, 1, 3])\n",
        "\n",
        "  return model\n",
        "\n",
        "# Download example gif\n",
        "!wget https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif -O jumpingjack.gif -q\n",
        "\n",
        "def load_gif(file_path, image_size=(224, 224)):\n",
        "  \"\"\"Loads a gif file into a TF tensor.\"\"\"\n",
        "  with tf.io.gfile.GFile(file_path, 'rb') as f:\n",
        "    video = tf.io.decode_gif(f.read())\n",
        "  video = tf.image.resize(video, image_size)\n",
        "  video = tf.cast(video, tf.float32) / 255.\n",
        "  return video\n",
        "\n",
        "def get_top_k_streaming_labels(probs, k=5, label_map=KINETICS_600_LABELS_LIST):\n",
        "  \"\"\"Returns the top-k labels over an entire video sequence.\n",
        "\n",
        "  Args:\n",
        "    probs: probability tensor of shape (num_frames, num_classes) that represents\n",
        "      the probability of each class on each frame.\n",
        "    k: the number of top predictions to select.\n",
        "    label_map: a list of labels to map logit indices to label strings.\n",
        "\n",
        "  Returns:\n",
        "    a tuple of the top-k probabilities, labels, and logit indices\n",
        "  \"\"\"\n",
        "  top_categories_last = tf.argsort(probs, -1, 'DESCENDING')[-1, :1]\n",
        "  categories = tf.argsort(probs, -1, 'DESCENDING')[:, :k]\n",
        "  categories = tf.reshape(categories, [-1])\n",
        "\n",
        "  counts = sorted([\n",
        "      (i.numpy(), tf.reduce_sum(tf.cast(categories == i, tf.int32)).numpy())\n",
        "      for i in tf.unique(categories)[0]\n",
        "  ], key=lambda x: x[1], reverse=True)\n",
        "\n",
        "  top_probs_idx = tf.constant([i for i, _ in counts[:k]])\n",
        "  top_probs_idx = tf.concat([top_categories_last, top_probs_idx], 0)\n",
        "  top_probs_idx = tf.unique(top_probs_idx)[0][:k+1]\n",
        "\n",
        "  top_probs = tf.gather(probs, top_probs_idx, axis=-1)\n",
        "  top_probs = tf.transpose(top_probs, perm=(1, 0))\n",
        "  top_labels = tf.gather(label_map, top_probs_idx, axis=0)\n",
        "  top_labels = [label.decode('utf8') for label in top_labels.numpy()]\n",
        "\n",
        "  return top_probs, top_labels, top_probs_idx\n",
        "\n",
        "def plot_streaming_top_preds_at_step(\n",
        "    top_probs,\n",
        "    top_labels,\n",
        "    step=None,\n",
        "    image=None,\n",
        "    legend_loc='lower left',\n",
        "    duration_seconds=10,\n",
        "    figure_height=500,\n",
        "    playhead_scale=0.8,\n",
        "    grid_alpha=0.3):\n",
        "  \"\"\"Generates a plot of the top video model predictions at a given time step.\n",
        "\n",
        "  Args:\n",
        "    top_probs: a tensor of shape (k, num_frames) representing the top-k\n",
        "      probabilities over all frames.\n",
        "    top_labels: a list of length k that represents the top-k label strings.\n",
        "    step: the current time step in the range [0, num_frames); defaults to the last frame.\n",
        "    image: the image frame to display at the current time step.\n",
        "    legend_loc: the placement location of the legend.\n",
        "    duration_seconds: the total duration of the video.\n",
        "    figure_height: the output figure height.\n",
        "    playhead_scale: scale value for the playhead.\n",
        "    grid_alpha: alpha value for the gridlines.\n",
        "\n",
        "  Returns:\n",
        "    A tuple of the output numpy image, figure, and axes.\n",
        "  \"\"\"\n",
        "  num_labels, num_frames = top_probs.shape\n",
        "  if step is None:\n",
        "    step = num_frames - 1\n",
        "\n",
        "  fig = plt.figure(figsize=(6.5, 7), dpi=300)\n",
        "  gs = mpl.gridspec.GridSpec(8, 1)\n",
        "  ax2 = plt.subplot(gs[:-3, :])\n",
        "  ax = plt.subplot(gs[-3:, :])\n",
        "\n",
        "  if image is not None:\n",
        "    ax2.imshow(image, interpolation='nearest')\n",
        "    ax2.axis('off')\n",
        "\n",
        "  preview_line_x = tf.linspace(0., duration_seconds, num_frames)\n",
        "  preview_line_y = top_probs\n",
        "\n",
        "  line_x = preview_line_x[:step+1]\n",
        "  line_y = preview_line_y[:, :step+1]\n",
        "\n",
        "  for i in range(num_labels):\n",
        "    ax.plot(preview_line_x, preview_line_y[i], label=None, linewidth='1.5',\n",
        "            linestyle=':', color='gray')\n",
        "    ax.plot(line_x, line_y[i], label=top_labels[i], linewidth='2.0')\n",
        "\n",
        "\n",
        "  ax.grid(which='major', linestyle=':', linewidth='1.0', alpha=grid_alpha)\n",
        "  ax.grid(which='minor', linestyle=':', linewidth='0.5', alpha=grid_alpha)\n",
        "\n",
        "  min_height = tf.reduce_min(top_probs) * playhead_scale\n",
        "  max_height = tf.reduce_max(top_probs)\n",
        "  ax.vlines(preview_line_x[step], min_height, max_height, colors='red')\n",
        "  ax.scatter(preview_line_x[step], max_height, color='red')\n",
        "\n",
        "  ax.legend(loc=legend_loc)\n",
        "\n",
        "  plt.xlim(0, duration_seconds)\n",
        "  plt.ylabel('Probability')\n",
        "  plt.xlabel('Time (s)')\n",
        "  plt.yscale('log')\n",
        "\n",
        "  fig.tight_layout()\n",
        "  fig.canvas.draw()\n",
        "\n",
        "  # buffer_rgba() returns an RGBA buffer of shape (height, width, 4); keep RGB.\n",
        "  data = np.ascontiguousarray(np.asarray(fig.canvas.buffer_rgba())[..., :3])\n",
        "  plt.close(fig)\n",
        "\n",
        "  figure_width = int(figure_height * data.shape[1] / data.shape[0])\n",
        "  image = PIL.Image.fromarray(data).resize([figure_width, figure_height])\n",
        "  image = np.array(image)\n",
        "\n",
        "  return image, (fig, ax, ax2)\n",
        "\n",
        "def plot_streaming_top_preds(\n",
        "    probs,\n",
        "    video,\n",
        "    top_k=5,\n",
        "    video_fps=25.,\n",
        "    figure_height=500,\n",
        "    use_progbar=True):\n",
        "  \"\"\"Generates a video plot of the top video model predictions.\n",
        "\n",
        "  Args:\n",
        "    probs: probability tensor of shape (num_frames, num_classes) that represents\n",
        "      the probability of each class on each frame.\n",
        "    video: the video to display in the plot.\n",
        "    top_k: the number of top predictions to select.\n",
        "    video_fps: the input video fps.\n",
        "    figure_height: the height of the output video.\n",
        "    use_progbar: display a progress bar.\n",
        "\n",
        "  Returns:\n",
        "    A numpy array representing the output video.\n",
        "  \"\"\"\n",
        "  steps = video.shape[0]\n",
        "  duration = steps / video_fps\n",
        "\n",
        "  top_probs, top_labels, _ = get_top_k_streaming_labels(probs, k=top_k)\n",
        "\n",
        "  images = []\n",
        "  step_generator = tqdm.trange(steps) if use_progbar else range(steps)\n",
        "  for i in step_generator:\n",
        "    image, _ = plot_streaming_top_preds_at_step(\n",
        "        top_probs=top_probs,\n",
        "        top_labels=top_labels,\n",
        "        step=i,\n",
        "        image=video[i],\n",
        "        duration_seconds=duration,\n",
        "        figure_height=figure_height,\n",
        "    )\n",
        "    images.append(image)\n",
        "\n",
        "  return np.array(images)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6g0tuFvf71S9"
      },
      "source": [
        "## Run Base Model Inference with TensorFlow Hub\n",
        "\n",
        "We will load MoViNet-A2-Base from TensorFlow Hub as part of the [MoViNet collection](https://tfhub.dev/google/collections/movinet/).\n",
        "\n",
        "The following code will:\n",
        "\n",
        "- Load a MoViNet KerasLayer from [tfhub.dev](https://tfhub.dev).\n",
        "- Wrap the layer in a [Keras Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model).\n",
        "- Load an example gif as a video.\n",
        "- Classify the video and print the top-5 predicted classes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KZKKNZVBpglJ"
      },
      "outputs": [],
      "source": [
        "model = load_movinet_from_hub('a2', 'base', hub_version=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7kU1_pL10l0B"
      },
      "source": [
        "As a simple example video for classification, we load a short GIF of someone performing jumping jacks.\n",
        "\n",
        "![jumping jacks](https://github.com/tensorflow/models/raw/f8af2291cced43fc9f1d9b41ddbf772ae7b0d7d2/official/projects/movinet/files/jumpingjack.gif)\n",
        "\n",
        "Attribution: Footage shared by [Coach Bobby Bluford](https://www.youtube.com/watch?v=-AxHpj-EuPg) on YouTube under the CC-BY license."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Iy0rKRrT723_"
      },
      "outputs": [],
      "source": [
        "video = load_gif('jumpingjack.gif', image_size=(172, 172))\n",
        "\n",
        "# Show video\n",
        "print(video.shape)\n",
        "media.show_video(video.numpy(), fps=5)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P0bZfrAsqPv2",
        "outputId": "fe2074c1-684e-4973-af76-c6b6deee511b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1/1 [==============================] - 18s 18s/step\n",
            "jumping jacks 0.9166436\n",
            "zumba 0.016020758\n",
            "doing aerobics 0.008053949\n",
            "dancing charleston 0.006083598\n",
            "lunge 0.0035062768\n"
          ]
        }
      ],
      "source": [
        "# Run the model on the video and output the top 5 predictions\n",
        "outputs = predict_top_k(model, video)\n",
        "\n",
        "for label, prob in outputs:\n",
        "  print(label, prob)"
      ]
    },
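    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what `predict_top_k` does with the model output, the NumPy-only sketch below repeats its softmax and top-k steps on a toy example. A softmax turns logits into probabilities, and sorting class indices by probability in descending order gives the top-k labels (`np.argsort(...)[::-1]` stands in for `tf.argsort(..., direction='DESCENDING')`). The four labels and the logit values are made up for illustration; the real model outputs 600 Kinetics 600 logits."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical 4-class example; the real model outputs 600 logits.\n",
        "toy_labels = np.array(['jumping jacks', 'zumba', 'lunge', 'doing aerobics'])\n",
        "toy_logits = np.array([4.0, 1.0, 0.5, 2.0])\n",
        "\n",
        "# Softmax, shifted by the max logit for numerical stability.\n",
        "exp_logits = np.exp(toy_logits - toy_logits.max())\n",
        "toy_probs = exp_logits / exp_logits.sum()\n",
        "\n",
        "# Sort class indices by probability, highest first, and keep the top k.\n",
        "k = 3\n",
        "top_idx = np.argsort(toy_probs)[::-1][:k]\n",
        "toy_top_k = [(str(toy_labels[i]), float(toy_probs[i])) for i in top_idx]\n",
        "\n",
        "for label, prob in toy_top_k:\n",
        "  print(label, prob)"
      ]
    },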
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ADrHPmwGcBZ5"
      },
      "source": [
        "## Run Streaming Model Inference with TensorFlow Hub and Plot Predictions\n",
        "\n",
        "We will load MoViNet-A2-Stream from TensorFlow Hub as part of the [MoViNet collection](https://tfhub.dev/google/collections/movinet/).\n",
        "\n",
        "The following code will:\n",
        "\n",
        "- Load a MoViNet-A2-Stream model from [tfhub.dev](https://tfhub.dev).\n",
        "- Classify an example video one frame at a time, passing the model's states from each call to the next, and plot the streaming predictions over time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tXWR13wthnK5"
      },
      "outputs": [],
      "source": [
        "model = load_movinet_from_hub('a2', 'stream', hub_version=3)\n",
        "\n",
        "# Create initial states for the stream model\n",
        "init_states_fn = model.layers[-1].resolved_object.signatures['init_states']\n",
        "init_states = init_states_fn(tf.shape(video[tf.newaxis]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YqSkt7l8ltwt",
        "outputId": "cdee6358-cc78-48a8-dc0a-cb4c502a3368"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 13/13 [00:10\u003c00:00,  1.23it/s]"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "jumping jacks 0.9998122\n",
            "zumba 0.00011835461\n",
            "doing aerobics 3.3375778e-05\n",
            "dancing charleston 4.9820073e-06\n",
            "finger snapping 3.867353e-06\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "# Insert your video clip here\n",
        "video = load_gif('jumpingjack.gif', image_size=(172, 172))\n",
        "clips = tf.split(video[tf.newaxis], video.shape[0], axis=1)\n",
        "\n",
        "all_logits = []\n",
        "\n",
        "# To run on a video, pass in one frame at a time\n",
        "states = init_states\n",
        "for clip in tqdm.tqdm(clips):\n",
        "  # Input shape: [1, 1, 172, 172, 3]\n",
        "  logits, states = model.predict({**states, 'image': clip}, verbose=0)\n",
        "  all_logits.append(logits)\n",
        "\n",
        "logits = tf.concat(all_logits, 0)\n",
        "probs = tf.nn.softmax(logits)\n",
        "\n",
        "final_probs = probs[-1]\n",
        "top_k = get_top_k(final_probs)\n",
        "print()\n",
        "for label, prob in top_k:\n",
        "  print(label, prob)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xdox556CtMRb",
        "outputId": "41a242ce-87aa-430e-d75c-516935c1da3b"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 13/13 [00:06\u003c00:00,  1.90it/s]\n"
          ]
        }
      ],
      "source": [
        "# Generate a plot and output to a video tensor\n",
        "plot_video = plot_streaming_top_preds(probs, video, video_fps=8.)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NSStKE9klCs3"
      },
      "outputs": [],
      "source": [
        "# For gif format, set codec='gif'\n",
        "media.show_video(plot_video, fps=3)"
      ]
    },
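    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The streaming loop above splits the video into single-frame clips and feeds each clip to the model together with the states returned by the previous call. The NumPy-only sketch below illustrates that calling pattern with a hypothetical `toy_stream_step` function whose state is just a running sum and a frame count. MoViNet's real states are internal tensors of the network, so this shows the bookkeeping only, not the model. For this toy causal computation, the output after the last frame matches the result of processing the whole video at once."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Toy stand-in for a video: 13 frames of 8x8 RGB (the notebook uses 172x172).\n",
        "rng = np.random.default_rng(0)\n",
        "toy_video = rng.random((13, 8, 8, 3)).astype(np.float32)\n",
        "\n",
        "# Add a batch axis and split along time (axis 1) into single-frame clips,\n",
        "# mirroring tf.split(video[tf.newaxis], video.shape[0], axis=1).\n",
        "toy_clips = np.split(toy_video[np.newaxis], toy_video.shape[0], axis=1)\n",
        "print(len(toy_clips), toy_clips[0].shape)  # each clip has shape (1, 1, 8, 8, 3)\n",
        "\n",
        "def toy_stream_step(clip, state):\n",
        "  # Hypothetical causal 'model': running mean of per-frame average intensity.\n",
        "  # The state carries everything needed from earlier frames.\n",
        "  total, count = state\n",
        "  total = total + clip.mean()\n",
        "  count = count + 1\n",
        "  return total / count, (total, count)\n",
        "\n",
        "toy_state = (0.0, 0)  # plays the role of init_states\n",
        "toy_outputs = []\n",
        "for clip in toy_clips:\n",
        "  output, toy_state = toy_stream_step(clip, toy_state)\n",
        "  toy_outputs.append(output)\n",
        "\n",
        "# Non-streaming result over the whole video, for comparison.\n",
        "full_video_output = toy_video.mean(axis=(1, 2, 3)).mean()\n",
        "print(np.isclose(toy_outputs[-1], full_video_output))"
      ]
    },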
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W3CLHvubvdSI"
      },
      "source": [
        "## Export a Streaming Model to TensorFlow Lite for Mobile\n",
        "\n",
        "We will convert a MoViNet-A0-Stream model to [TensorFlow Lite](https://www.tensorflow.org/lite).\n",
        "\n",
        "The following code will:\n",
        "\n",
        "- Load a MoViNet-A0-Stream model and its pretrained checkpoint.\n",
        "- Export the model as a SavedModel and convert it to TF Lite.\n",
        "- Run inference on an example video using the TF Lite Python interpreter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KH0j-07KVh06"
      },
      "outputs": [],
      "source": [
        "from official.projects.movinet.modeling import movinet\n",
        "from official.projects.movinet.modeling import movinet_model\n",
        "from official.projects.movinet.tools import export_saved_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5DGwH9qi87Oe",
        "outputId": "f28271e5-afc1-4fb2-9ad5-1e6fdd89a64f"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "movinet_a0_stream/\n",
            "movinet_a0_stream/ckpt-1.data-00000-of-00001\n",
            "movinet_a0_stream/ckpt-1.index\n",
            "movinet_a0_stream/checkpoint\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "\u003ctensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f0b0dc436a0\u003e"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model_id = 'a0'\n",
        "use_positional_encoding = model_id in {'a3', 'a4', 'a5'}\n",
        "\n",
        "# Create backbone and model.\n",
        "backbone = movinet.Movinet(\n",
        "    model_id=model_id,\n",
        "    causal=True,\n",
        "    conv_type='2plus1d',\n",
        "    se_type='2plus3d',\n",
        "    activation='hard_swish',\n",
        "    gating_activation='hard_sigmoid',\n",
        "    use_positional_encoding=use_positional_encoding,\n",
        "    use_external_states=True,\n",
        ")\n",
        "\n",
        "model = movinet_model.MovinetClassifier(\n",
        "    backbone,\n",
        "    num_classes=600,\n",
        "    output_states=True)\n",
        "\n",
        "# Create your example input here.\n",
        "# Refer to the paper for recommended input shapes.\n",
        "inputs = tf.ones([1, 13, 172, 172, 3])\n",
        "\n",
        "# [Optional] Build the model and load a pretrained checkpoint.\n",
        "model.build(inputs.shape)\n",
        "\n",
        "\n",
        "# Download and extract the pretrained weights\n",
        "!wget https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_stream.tar.gz -O movinet_a0_stream.tar.gz -q\n",
        "!tar -xvf movinet_a0_stream.tar.gz\n",
        "\n",
        "checkpoint_dir = 'movinet_a0_stream'\n",
        "checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n",
        "checkpoint = tf.train.Checkpoint(model=model)\n",
        "status = checkpoint.restore(checkpoint_path)\n",
        "status.assert_existing_objects_matched()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RLkV0xtPvfkY"
      },
      "outputs": [],
      "source": [
        "# Export to saved model\n",
        "saved_model_dir = 'model'\n",
        "tflite_filename = 'model.tflite'\n",
        "input_shape = [1, 1, 172, 172, 3]\n",
        "\n",
        "# Convert to saved model\n",
        "export_saved_model.export_saved_model(\n",
        "    model=model,\n",
        "    input_shape=input_shape,\n",
        "    export_path=saved_model_dir,\n",
        "    causal=True,\n",
        "    bundle_input_init_states_fn=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gPg_6eMC8IwF"
      },
      "outputs": [],
      "source": [
        "# Convert to TF Lite\n",
        "converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)\n",
        "tflite_model = converter.convert()\n",
        "\n",
        "with open(tflite_filename, 'wb') as f:\n",
        "  f.write(tflite_model)\n",
        "\n",
        "# Create the interpreter and signature runner\n",
        "interpreter = tf.lite.Interpreter(model_path=tflite_filename)\n",
        "runner = interpreter.get_signature_runner()\n",
        "\n",
        "init_states = {\n",
        "    name: tf.zeros(x['shape'], dtype=x['dtype'])\n",
        "    for name, x in runner.get_input_details().items()\n",
        "}\n",
        "del init_states['image']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-TQ-7oSJIlTA",
        "outputId": "2a7cf5f5-7648-44dd-a5d5-69da9ea82838"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\n",
            "jumping jacks 0.9733523\n",
            "jogging 0.0032490466\n",
            "stretching arm 0.002780116\n",
            "riding unicycle 0.0019377996\n",
            "passing soccer ball 0.0016310472\n"
          ]
        }
      ],
      "source": [
        "# Insert your video clip here\n",
        "video = load_gif('jumpingjack.gif', image_size=(172, 172))\n",
        "clips = tf.split(video[tf.newaxis], video.shape[0], axis=1)\n",
        "\n",
        "# To run on a video, pass in one frame at a time\n",
        "states = init_states\n",
        "for clip in clips:\n",
        "  # Input shape: [1, 1, 172, 172, 3]\n",
        "  outputs = runner(**states, image=clip)\n",
        "  logits = outputs.pop('logits')[0]\n",
        "  states = outputs\n",
        "\n",
        "probs = tf.nn.softmax(logits)\n",
        "top_k = get_top_k(probs)\n",
        "print()\n",
        "for label, prob in top_k:\n",
        "  print(label, prob)"
      ]
    },
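    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The TF Lite loop above follows a simple protocol. Each call passes the image plus every state as keyword arguments, and the runner returns the logits together with updated states under the same names, so popping `'logits'` leaves exactly the state inputs for the next call. The NumPy-only sketch below mimics that protocol without calling TF Lite. `fake_runner`, its state names (`state_a`, `state_b`), and their shapes are hypothetical; the real model's states have different names and shapes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical input details, shaped like the dict that\n",
        "# runner.get_input_details() returns (input name -> details with 'shape' and\n",
        "# 'dtype'). The state names and shapes are made up for illustration.\n",
        "toy_input_details = {\n",
        "    'image': {'shape': np.array([1, 1, 4, 4, 3]), 'dtype': np.float32},\n",
        "    'state_a': {'shape': np.array([1, 2]), 'dtype': np.float32},\n",
        "    'state_b': {'shape': np.array([1]), 'dtype': np.int32},\n",
        "}\n",
        "\n",
        "# Zero-initialize every input, then drop 'image', which is fed on each step.\n",
        "toy_init_states = {\n",
        "    name: np.zeros(d['shape'], dtype=d['dtype'])\n",
        "    for name, d in toy_input_details.items()\n",
        "}\n",
        "del toy_init_states['image']\n",
        "\n",
        "def fake_runner(image, state_a, state_b):\n",
        "  # Hypothetical stand-in for the TF Lite signature runner: returns 'logits'\n",
        "  # plus updated states under the same names as the state inputs.\n",
        "  new_a = state_a + image.mean()\n",
        "  new_b = state_b + 1\n",
        "  logits = np.concatenate([new_a, new_b[:, None].astype(np.float32)], axis=-1)\n",
        "  return {'logits': logits, 'state_a': new_a, 'state_b': new_b}\n",
        "\n",
        "# Five all-ones frames, split into single-frame clips of shape (1, 1, 4, 4, 3).\n",
        "toy_frames = np.ones((5, 4, 4, 3), dtype=np.float32)\n",
        "toy_clips = np.split(toy_frames[np.newaxis], toy_frames.shape[0], axis=1)\n",
        "\n",
        "toy_states = toy_init_states\n",
        "for clip in toy_clips:\n",
        "  step_outputs = fake_runner(**toy_states, image=clip)\n",
        "  toy_logits = step_outputs.pop('logits')[0]  # what remains is the next state\n",
        "  toy_states = step_outputs\n",
        "\n",
        "print(sorted(toy_states), toy_logits)"
      ]
    },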
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_s-7bEoa3f8g"
      },
      "source": [
        "## Fine-Tune a Base Model with the TensorFlow Model Garden\n",
        "\n",
        "We will fine-tune MoViNet-A0-Base on [UCF-101](https://www.crcv.ucf.edu/research/data-sets/ucf101/).\n",
        "\n",
        "The following code will:\n",
        "\n",
        "- Load the UCF-101 dataset with [TensorFlow Datasets](https://www.tensorflow.org/datasets/catalog/ucf101).\n",
        "- Create a simple [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) pipeline for training and evaluation.\n",
        "- Display some example videos from the dataset.\n",
        "- Build a MoViNet model and load pretrained weights.\n",
        "- Fine-tune the final classifier layers on UCF-101 and evaluate accuracy on the validation set."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7unW4WVr580"
      },
      "source": [
        "### Load the UCF-101 Dataset with TensorFlow Datasets\n",
        "\n",
        "Calling `download_and_prepare()` downloads and prepares the dataset automatically. The download is about 6.5 GiB, so this step may take up to an hour depending on download and extraction speed. After downloading, the next cell outputs information about the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2IHLbPAfrs5P"
      },
      "outputs": [],
      "source": [
        "# Run imports\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "from official.vision.configs import video_classification\n",
        "from official.projects.movinet.configs import movinet as movinet_configs\n",
        "from official.projects.movinet.modeling import movinet\n",
        "from official.projects.movinet.modeling import movinet_layers\n",
        "from official.projects.movinet.modeling import movinet_model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FxM1vNYp_YAM"
      },
      "outputs": [],
      "source": [
        "dataset_name = 'ucf101'\n",
        "\n",
        "builder = tfds.builder(dataset_name)\n",
        "\n",
        "config = tfds.download.DownloadConfig(verify_ssl=False)\n",
        "builder.download_and_prepare(download_config=config)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "boQHbcfDhXpJ",
        "outputId": "e3e8b59d-861a-4358-f19a-99b059d7eb47"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Number of classes: 101\n",
            "Number of examples for train: 9537\n",
            "Number of examples for test: 3783\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "tfds.core.DatasetInfo(\n",
              "    name='ucf101',\n",
              "    full_name='ucf101/ucf101_1_256/2.0.0',\n",
              "    description=\"\"\"\n",
              "    A 101-label video classification dataset.\n",
              "    \"\"\",\n",
              "    config_description=\"\"\"\n",
              "    256x256 UCF with the first action recognition split.\n",
              "    \"\"\",\n",
              "    homepage='https://www.crcv.ucf.edu/data-sets/ucf101/',\n",
              "    data_path='~/tensorflow_datasets/ucf101/ucf101_1_256/2.0.0',\n",
              "    file_format=tfrecord,\n",
              "    download_size=6.48 GiB,\n",
              "    dataset_size=7.41 GiB,\n",
              "    features=FeaturesDict({\n",
              "        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=101),\n",
              "        'video': Video(Image(shape=(256, 256, 3), dtype=tf.uint8)),\n",
              "    }),\n",
              "    supervised_keys=None,\n",
              "    disable_shuffling=False,\n",
              "    splits={\n",
              "        'test': \u003cSplitInfo num_examples=3783, num_shards=32\u003e,\n",
              "        'train': \u003cSplitInfo num_examples=9537, num_shards=64\u003e,\n",
              "    },\n",
              "    citation=\"\"\"@article{DBLP:journals/corr/abs-1212-0402,\n",
              "      author    = {Khurram Soomro and\n",
              "                   Amir Roshan Zamir and\n",
              "                   Mubarak Shah},\n",
              "      title     = {{UCF101:} {A} Dataset of 101 Human Actions Classes From Videos in\n",
              "                   The Wild},\n",
              "      journal   = {CoRR},\n",
              "      volume    = {abs/1212.0402},\n",
              "      year      = {2012},\n",
              "      url       = {http://arxiv.org/abs/1212.0402},\n",
              "      archivePrefix = {arXiv},\n",
              "      eprint    = {1212.0402},\n",
              "      timestamp = {Mon, 13 Aug 2018 16:47:45 +0200},\n",
              "      biburl    = {https://dblp.org/rec/bib/journals/corr/abs-1212-0402},\n",
              "      bibsource = {dblp computer science bibliography, https://dblp.org}\n",
              "    }\"\"\",\n",
              ")"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "num_classes = builder.info.features['label'].num_classes\n",
        "num_examples = {\n",
        "    name: split.num_examples\n",
        "    for name, split in builder.info.splits.items()\n",
        "}\n",
        "\n",
        "print('Number of classes:', num_classes)\n",
        "print('Number of examples for train:', num_examples['train'])\n",
        "print('Number of examples for test:', num_examples['test'])\n",
        "print()\n",
        "\n",
        "builder.info"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9cO_BCu9le3r"
      },
      "outputs": [],
      "source": [
        "# Build the training and evaluation datasets.\n",
        "\n",
        "batch_size = 8\n",
        "num_frames = 8\n",
        "frame_stride = 10\n",
        "resolution = 172\n",
        "\n",
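        "def format_features(features):\n",
        "  # `video` is batched: [batch, time, height, width, 3], uint8.\n",
        "  video = features['video']\n",
        "  # Keep every `frame_stride`-th frame, then the first `num_frames` of those.\n",
        "  video = video[:, ::frame_stride]\n",
        "  video = video[:, :num_frames]\n",
        "\n",
        "  # Merge the batch and time dims so tf.image.resize sees a batch of images,\n",
        "  # resize every frame, then restore the [batch, time, ...] layout.\n",
        "  video = tf.reshape(video, [-1, video.shape[2], video.shape[3], 3])\n",
        "  video = tf.image.resize(video, (resolution, resolution))\n",
        "  video = tf.reshape(video, [-1, num_frames, resolution, resolution, 3])\n",
        "  # Scale pixel values to [0, 1].\n",
        "  video = tf.cast(video, tf.float32) / 255.\n",
        "\n",
        "  # One-hot labels, as expected by the CategoricalCrossentropy loss used below.\n",
        "  label = tf.one_hot(features['label'], num_classes)\n",
        "  return (video, label)\n",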
        "\n",
        "train_dataset = builder.as_dataset(\n",
        "    split='train',\n",
        "    batch_size=batch_size,\n",
        "    shuffle_files=True)\n",
        "train_dataset = train_dataset.map(\n",
        "    format_features,\n",
        "    num_parallel_calls=tf.data.AUTOTUNE)\n",
        "train_dataset = train_dataset.repeat()\n",
        "train_dataset = train_dataset.prefetch(2)\n",
        "\n",
        "test_dataset = builder.as_dataset(\n",
        "    split='test',\n",
        "    batch_size=batch_size)\n",
        "test_dataset = test_dataset.map(\n",
        "    format_features,\n",
        "    num_parallel_calls=tf.data.AUTOTUNE,\n",
        "    deterministic=True)\n",
        "test_dataset = test_dataset.prefetch(2)"
      ]
    },
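    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`format_features` keeps every `frame_stride`-th frame and then truncates to the first `num_frames` of those, so every clip is built from the same frame positions. The next cell is a small pure-Python sketch of that index arithmetic; it does not read the dataset, the helper `sampled_frame_indices` is a hypothetical name introduced only for illustration, and the example call assumes a hypothetical 150-frame video. With `frame_stride = 10` and `num_frames = 8`, a clip uses source frames 0, 10, ..., 70, so it covers the first 71 frames of each video."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch only: applies the same slicing as `format_features`\n",
        "# (video[:, ::frame_stride][:, :num_frames]) to frame indices.\n",
        "# The defaults match the settings used in the cell above.\n",
        "def sampled_frame_indices(num_source_frames, frame_stride=10, num_frames=8):\n",
        "  \"\"\"Returns the source-frame indices kept by stride-then-truncate sampling.\"\"\"\n",
        "  strided = list(range(num_source_frames))[::frame_stride]\n",
        "  return strided[:num_frames]\n",
        "\n",
        "# Hypothetical 150-frame video.\n",
        "indices = sampled_frame_indices(num_source_frames=150)\n",
        "print('Sampled frame indices:', indices)\n",
        "print('Source frames spanned:', indices[-1] + 1)"
      ]
    },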
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rToX7_Ymgh57"
      },
      "source": [
        "Display some example videos from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KG8Z7rUj06of"
      },
      "outputs": [],
      "source": [
        "videos, labels = next(iter(train_dataset))\n",
        "media.show_videos(videos.numpy(), codec='gif', fps=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R3RHeuHdsd_3"
      },
      "source": [
        "### Build MoViNet-A0-Base and Load Pretrained Weights\n",
        "\n",
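        "Here we create a MoViNet model using the open source code provided in [official/projects/movinet](https://github.com/tensorflow/models/tree/master/official/projects/movinet) and load the pretrained weights. Because the checkpoint was trained on Kinetics 600, the model is first built with a 600-class head so that its variables match the checkpoint; the restored backbone is then wrapped in a new classifier with one output per UCF-101 class. We freeze all layers except this final classifier head to speed up fine-tuning."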
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 6934,
          "status": "ok",
          "timestamp": 1674680181444,
          "user": {
            "displayName": "Siva Sravana Kumar Neeli",
            "userId": "06669604936988620923"
          },
          "user_tz": 480
        },
        "id": "JpfxpeGSsbzJ",
        "outputId": "83a49ab1-b28e-45c6-c0b3-2fc446944f65"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "movinet_a0_base/\n",
            "movinet_a0_base/checkpoint\n",
            "movinet_a0_base/ckpt-1.data-00000-of-00001\n",
            "movinet_a0_base/ckpt-1.index\n"
          ]
        }
      ],
      "source": [
        "model_id = 'a0'\n",
        "\n",
        "tf.keras.backend.clear_session()\n",
        "\n",
        "backbone = movinet.Movinet(model_id=model_id)\n",
        "model = movinet_model.MovinetClassifier(backbone=backbone, num_classes=600)\n",
        "model.build([1, 1, 1, 1, 3])\n",
        "\n",
        "# Load pretrained weights\n",
        "!wget https://storage.googleapis.com/tf_model_garden/vision/movinet/movinet_a0_base.tar.gz -O movinet_a0_base.tar.gz -q\n",
        "!tar -xvf movinet_a0_base.tar.gz\n",
        "\n",
        "checkpoint_dir = 'movinet_a0_base'\n",
        "checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)\n",
        "checkpoint = tf.train.Checkpoint(model=model)\n",
        "status = checkpoint.restore(checkpoint_path)\n",
        "status.assert_existing_objects_matched()\n",
        "\n",
        "def build_classifier(backbone, num_classes, freeze_backbone=False):\n",
        "  \"\"\"Builds a classifier on top of a backbone model.\"\"\"\n",
        "  model = movinet_model.MovinetClassifier(\n",
        "      backbone=backbone,\n",
        "      num_classes=num_classes)\n",
        "  model.build([batch_size, num_frames, resolution, resolution, 3])\n",
        "\n",
        "  if freeze_backbone:\n",
        "    for layer in model.layers[:-1]:\n",
        "      layer.trainable = False\n",
        "    model.layers[-1].trainable = True\n",
        "\n",
        "  return model\n",
        "\n",
        "# Wrap the pretrained backbone in a new classifier whose head has\n",
        "# num_classes outputs (101 for UCF-101), and freeze every layer\n",
        "# except that final classifier head.\n",
        "model = build_classifier(backbone, num_classes, freeze_backbone=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ucntdu2xqgXB"
      },
      "source": [
        "Configure fine-tuning with training/evaluation steps, loss object, metrics, learning rate, optimizer, and callbacks.\n",
        "\n",
        "Here we use 3 epochs. Training for more epochs should improve accuracy."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WUYTw48BouTu"
      },
      "outputs": [],
      "source": [
        "num_epochs = 3\n",
        "\n",
        "train_steps = num_examples['train'] // batch_size\n",
        "total_train_steps = train_steps * num_epochs\n",
        "test_steps = num_examples['test'] // batch_size\n",
        "\n",
        "loss_obj = tf.keras.losses.CategoricalCrossentropy(\n",
        "    from_logits=True,\n",
        "    label_smoothing=0.1)\n",
        "\n",
        "metrics = [\n",
        "    tf.keras.metrics.TopKCategoricalAccuracy(\n",
        "        k=1, name='top_1', dtype=tf.float32),\n",
        "    tf.keras.metrics.TopKCategoricalAccuracy(\n",
        "        k=5, name='top_5', dtype=tf.float32),\n",
        "]\n",
        "\n",
        "initial_learning_rate = 0.01\n",
        "learning_rate = tf.keras.optimizers.schedules.CosineDecay(\n",
        "    initial_learning_rate, decay_steps=total_train_steps,\n",
        ")\n",
        "optimizer = tf.keras.optimizers.RMSprop(\n",
        "    learning_rate, rho=0.9, momentum=0.9, epsilon=1.0, clipnorm=1.0)\n",
        "\n",
        "model.compile(loss=loss_obj, optimizer=optimizer, metrics=metrics)\n",
        "\n",
        "callbacks = [\n",
        "    tf.keras.callbacks.TensorBoard(),\n",
        "]"
      ]
    },
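    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick check on these settings, the next cell recomputes the step counts and the learning-rate schedule in plain Python from the split sizes reported by `builder.info` (9537 training and 3783 test examples). The schedule function is a sketch based on the formula in the `tf.keras.optimizers.schedules.CosineDecay` documentation with its default `alpha=0.0` and no warmup; it is meant for intuition, not as a replacement for the Keras schedule. With a batch size of 8, each epoch is 1192 steps (the `1192/1192` in the training log), the learning rate falls from 0.01 to 0.005 at the halfway point and to 0 by the last step, and integer division leaves the final partial batch of 7 test examples out of each evaluation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Values copied from the cells above; split sizes come from `builder.info`.\n",
        "# New names are used so this sketch does not overwrite notebook variables.\n",
        "n_train, n_test = 9537, 3783\n",
        "bs, epochs, lr0 = 8, 3, 0.01\n",
        "\n",
        "steps_per_epoch = n_train // bs         # 1192, as in the fit log\n",
        "eval_steps = n_test // bs               # 472\n",
        "decay_steps = steps_per_epoch * epochs  # 3576\n",
        "skipped_test_examples = n_test - eval_steps * bs  # 7\n",
        "\n",
        "def cosine_decay_lr(step, initial_lr=lr0, total_steps=decay_steps, alpha=0.0):\n",
        "  \"\"\"Cosine decay as described in the Keras CosineDecay docs (no warmup).\"\"\"\n",
        "  step = min(step, total_steps)\n",
        "  cosine = 0.5 * (1 + math.cos(math.pi * step / total_steps))\n",
        "  return initial_lr * ((1 - alpha) * cosine + alpha)\n",
        "\n",
        "for step in (0, decay_steps // 2, decay_steps):\n",
        "  print(f'step {step:4d}: learning rate {cosine_decay_lr(step):.4f}')\n",
        "print('Test examples skipped per evaluation:', skipped_test_examples)"
      ]
    },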
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0IyAOOlcpHna"
      },
      "source": [
        "Run the fine-tuning with Keras `fit` (the model was compiled in the previous cell). After 3 epochs, the run below reaches about 84% top-1 and 97.5% top-5 accuracy on the test set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "executionInfo": {
          "elapsed": 3426904,
          "status": "ok",
          "timestamp": 1674683608342,
          "user": {
            "displayName": "Siva Sravana Kumar Neeli",
            "userId": "06669604936988620923"
          },
          "user_tz": 480
        },
        "id": "Zecc_K3lga8I",
        "outputId": "1946f687-aece-49f5-f5e7-09aad9f9882b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/3\n",
            "1192/1192 [==============================] - 1151s 949ms/step - loss: 2.5097 - top_1: 0.6726 - top_5: 0.8745 - val_loss: 1.6358 - val_top_1: 0.8125 - val_top_5: 0.9666\n",
            "Epoch 2/3\n",
            "1192/1192 [==============================] - 1138s 951ms/step - loss: 1.3347 - top_1: 0.9062 - top_5: 0.9894 - val_loss: 1.4627 - val_top_1: 0.8400 - val_top_5: 0.9709\n",
            "Epoch 3/3\n",
            "1192/1192 [==============================] - 1138s 955ms/step - loss: 1.2301 - top_1: 0.9340 - top_5: 0.9943 - val_loss: 1.4386 - val_top_1: 0.8438 - val_top_5: 0.9751\n"
          ]
        }
      ],
      "source": [
        "results = model.fit(\n",
        "    train_dataset,\n",
        "    validation_data=test_dataset,\n",
        "    epochs=num_epochs,\n",
        "    steps_per_epoch=train_steps,\n",
        "    validation_steps=test_steps,\n",
        "    callbacks=callbacks,\n",
        "    validation_freq=1,\n",
        "    verbose=1)"
      ]
    },
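    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The logged training loss only drops to about 1.23 by the third epoch, even though training top-1 accuracy exceeds 93%. Part of this gap comes from `label_smoothing=0.1` in the loss: Keras mixes each one-hot target with a uniform distribution, `y * (1 - 0.1) + 0.1 / num_classes`, so the cross-entropy cannot reach zero even for a model whose predictions match the smoothed targets exactly. The next cell is a plain-Python sketch of that lower bound for 101 classes, which works out to about 0.78; it is an illustration for reading the log, not a value Keras reports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Assumed to match the loss configured above:\n",
        "# CategoricalCrossentropy(from_logits=True, label_smoothing=0.1), 101 classes.\n",
        "n_classes = 101\n",
        "smoothing = 0.1\n",
        "\n",
        "# Smoothed target: y * (1 - smoothing) + smoothing / n_classes.\n",
        "p_true = (1 - smoothing) + smoothing / n_classes  # correct class\n",
        "p_other = smoothing / n_classes                   # each other class\n",
        "\n",
        "# Cross-entropy against a fixed target is minimized when the predicted\n",
        "# distribution equals the target; the minimum is the target's entropy.\n",
        "loss_floor = -(p_true * math.log(p_true)\n",
        "               + (n_classes - 1) * p_other * math.log(p_other))\n",
        "print(f'Smoothed target: correct class {p_true:.4f}, each other class {p_other:.5f}')\n",
        "print(f'Lowest achievable loss: {loss_floor:.3f}')"
      ]
    },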
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XuH8XflmpU9d"
      },
      "source": [
        "We can also view the training and evaluation progress in TensorBoard."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 839
        },
        "executionInfo": {
          "elapsed": 33,
          "status": "ok",
          "timestamp": 1674683608343,
          "user": {
            "displayName": "Siva Sravana Kumar Neeli",
            "userId": "06669604936988620923"
          },
          "user_tz": 480
        },
        "id": "9fZhzhRJRd2J",
        "outputId": "43a8b75e-28f2-456c-b6d7-5a02b65d2443"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Reusing TensorBoard on port 43479 (pid 278134), started 19:51:44 ago. (Use '!kill 278134' to kill it.)"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "\n",
              "        (async () =\u003e {\n",
              "            const url = new URL(await google.colab.kernel.proxyPort(43479, {'cache': true}));\n",
              "            url.searchParams.set('tensorboardColab', 'true');\n",
              "            const iframe = document.createElement('iframe');\n",
              "            iframe.src = url;\n",
              "            iframe.setAttribute('width', '100%');\n",
              "            iframe.setAttribute('height', '800');\n",
              "            iframe.setAttribute('frameborder', 0);\n",
              "            document.body.appendChild(iframe);\n",
              "        })();\n",
              "    "
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript object\u003e"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "%reload_ext tensorboard\n",
        "%tensorboard --logdir logs --port 0"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "provenance": [
        {
          "file_id": "1nV2uiAZgRk2Ble2kximcRZvCSv9c02Xd",
          "timestamp": 1674684623688
        },
        {
          "file_id": "11msGCxFjxwioBOBJavP9alfTclUQCJf-",
          "timestamp": 1617043059980
        }
      ]
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
